import json
import os
import sys
from typing import Callable
from collections import defaultdict
import torch
import torch.nn.utils.prune as prune
import torchmetrics
import torchvision
import wandb
from scipy import stats
from torch.nn.utils.prune import _compute_nparams_toprune, _validate_pruning_amount, _validate_pruning_amount_init


class Utilities:
    """Class of utility functions"""

    @staticmethod
    @torch.no_grad()
    def get_model_norm_square(paramList):
        """Get L2 norm squared of parameter vector. This works for a pruned model as well. The paramList must be in
        the format (module, paramType)"""
        squared_norm = 0.
        for module, paramType in paramList:
            param = getattr(module, paramType)
            squared_norm += torch.norm(param, p=2) ** 2
        return float(squared_norm)

    @staticmethod
    def dump_dict_to_json_wandb(dumpDict, name):
        """Serialise ``dumpDict`` to ``<wandb run dir>/<name>.json`` and register the file with wandb."""
        target_path = os.path.join(wandb.run.dir, f'{name}.json')
        with open(target_path, 'w') as handle:
            json.dump(dumpDict, handle)
        # Hand the written file over to wandb for upload
        wandb.save(target_path)

    @staticmethod
    def get_overloaded_dataset(OriginalDataset):
        class AlteredDatasetWrapper(OriginalDataset):

            def __init__(self, *args, **kwargs):

                super(AlteredDatasetWrapper, self).__init__(*args, **kwargs)

            def __getitem__(self, index):
                # Overload this to collect the class indices once in a vector, which can then be used in the sampler
                image, label = super(AlteredDatasetWrapper, self).__getitem__(index=index)
                return image, label, index

        AlteredDatasetWrapper.__name__ = OriginalDataset.__name__
        return AlteredDatasetWrapper

    @staticmethod
    @torch.no_grad()
    def entropy(t: torch.Tensor):
        """Computes entropy of tensor t, which is expected to have values in [0,1]"""
        min_positive = torch.finfo(t.dtype).tiny
        probs = torch.clamp(t, min=min_positive)
        p_log_p = probs * torch.log(probs)
        entropy = -p_log_p.sum()
        return entropy

    @staticmethod
    @torch.no_grad()
    def replace_layers_by_decomposition(model, svdSparsity, selection):
        """Replaces all convolutional modules by their truncated SVD"""
        assert 0 <= svdSparsity <= 1
        if selection == 'uniform':
            sys.stdout.write("\nPerforming uniform singular value selection.")
            mask = None
        elif selection == 'global':
            sys.stdout.write("\nPerforming global singular value selection.")

            svdValList = []
            # Get the singular values of each layer and concatenate
            for motherName, mother in list(model.named_modules()):
                for childName, child in list(mother.named_children()):
                    if isinstance(child, torch.nn.Conv2d) and not isinstance(child, DecomposedConv2d):
                        child_matrix = getattr(child, 'weight').flatten(start_dim=1)
                        child_svdvals = torch.linalg.svdvals(child_matrix)
                        svdValList.append(child_svdvals.flatten())
            lengthList = [svdv.numel() for svdv in svdValList]
            svdValVec = torch.cat(svdValList)
            k_singvals_to_keep = max(1, int((1 - svdSparsity) * svdValVec.numel()))  # Number of singular values to keep
            maxIndices = torch.topk(svdValVec, k=k_singvals_to_keep, largest=True).indices
            mask = torch.zeros_like(svdValVec)
            mask[maxIndices] = 1

        layerCtr = 0
        svdVals_seen = 0
        for motherName, mother in list(model.named_modules()):
            for childName, child in list(mother.named_children()):
                if isinstance(child, torch.nn.Conv2d) and not isinstance(child, DecomposedConv2d):
                    child_device = getattr(child, 'weight').device
                    if mask is None:
                        s = svdSparsity
                    else:
                        partialMask = mask[svdVals_seen:svdVals_seen + lengthList[layerCtr]]
                        s = torch.sum(partialMask == 0)/partialMask.numel()
                        svdVals_seen += lengthList[layerCtr]
                        layerCtr += 1
                    new_child = DecomposedConv2d(originalLayer=child, svdSparsity=s).to(child_device)
                    setattr(mother, childName, new_child)




    @staticmethod
    @torch.jit.script
    @torch.no_grad()
    def SVD_eigval(W):
        """Approximate one singular triplet of ``W`` by running LOBPCG on ``W.T @ W``.

        Args:
            W: 2-D tensor (asserted below).

        Returns:
            Tuple ``(u, v, sigma)``: ``v`` is the squeezed eigenvector returned by ``torch.lobpcg``,
            ``sigma`` the square root of the corresponding eigenvalue and ``u`` the squeezed vector
            ``W v / sigma``.

        NOTE(review): ``torch.no_grad()`` is applied underneath ``torch.jit.script``; confirm that the
        scripted function actually honours the no-grad context.
        """
        assert len(W.shape) == 2, "W must be a matrix"

        # Gram matrix W^T W; its eigenvalues are the squared singular values of W
        spdMatrix = torch.mm(W.t(), W)
        # k=1 requests a single eigenpair (presumably the largest one by default -- TODO confirm against the torch.lobpcg docs)
        eVal, eVec = torch.lobpcg(A=spdMatrix, k=1)
        sigma = torch.sqrt(eVal) # singular value
        v = eVec    # right singular vector
        # Left singular vector follows from W v = sigma * u
        u = 1./sigma * torch.mm(W, v)
        return u.squeeze(), v.squeeze(), sigma

    @staticmethod
    @torch.jit.script
    @torch.no_grad()
    def SVD_power_iteration(W):
        """Approximate the leading singular triplet of ``W`` with a fixed number of power iterations.

        Requires the gap assumption, i.e. the first singular value is separated from the others by a
        sufficiently large gap. Implementation mostly taken from PyTorch's spectral norm; an explanation
        of why this works can be found in:
        Simple Algorithms for the Partial Singular Value Decomposition (J.C. Nash)

        Args:
            W: 2-D tensor of shape ``(n, m)`` (asserted below).

        Returns:
            Tuple ``(u, v, sigma)`` with ``u`` of length ``n``, ``v`` of length ``m`` and
            ``sigma = u^T W v``. The start vectors are random, so results depend on the RNG state.

        NOTE(review): ``torch.no_grad()`` is applied underneath ``torch.jit.script``; confirm that the
        scripted function actually honours the no-grad context.
        """
        assert len(W.shape) == 2, "W must be a matrix"
        n_power_iterations = 20  # fixed iteration budget
        n, m = W.shape

        # Initialize random vectors u and v
        u = torch.nn.functional.normalize(W.new_empty(n).normal_(0., 1.), dim=0)
        v = torch.nn.functional.normalize(W.new_empty(m).normal_(0., 1.), dim=0)


        for _ in range(n_power_iterations):
            # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
            # are the first left and right singular vectors.
            # This power iteration produces approximations of `u` and `v`.
            # `out=` reuses the existing buffers of v and u for the normalized result
            v = torch.nn.functional.normalize(torch.mv(W.t(), u), dim=0, out=v)
            u = torch.nn.functional.normalize(torch.mv(W, v), dim=0, out=u)

        # Rayleigh-quotient style estimate of the singular value
        sigma = torch.dot(u, torch.mv(W, v))
        return u, v, sigma


    @staticmethod
    @torch.no_grad()
    def SVD_block_power_iteration(W, k=1):
        """"""
        assert len(W.shape) == 2, "W must be a matrix"
        assert k <= min(W.shape)
        n_power_iterations = 10
        n, m = W.shape

        # Initialize block vectors
        V = torch.nn.functional.normalize(W.new_empty(m, k).normal_(0., 1.), dim=1)
        for _ in range(n_power_iterations):
            Q, R = torch.linalg.qr(torch.mm(W, V))
            U = Q
            Q, R = torch.linalg.qr(torch.mm(W.t(), U))
            V = Q
        sigma = torch.diag(R)

        # It might be that the singular values have the wrong sign
        U = torch.sign(sigma)*U
        sigma = torch.abs(sigma)
        return U, V.t(), sigma

    @staticmethod
    @torch.no_grad()
    def SVD_partial(W, k=1):
        """Computes the first k singular value/vector pairs. Note: Right now this needs to compute the full SVD, hence
        yielding no speedup for small k."""
        assert len(W.shape) == 2, "W must be a matrix"
        assert k <= min(W.shape)

        # Compute full SVD
        U, S, V_t = torch.linalg.svd(W, full_matrices=False)
        n_singvals_to_keep = k  # Number of singular values to keep

        # Truncate matrices
        V_t = V_t[:n_singvals_to_keep, :]  # Keep only the first rows
        S = S[:n_singvals_to_keep]
        U = U[:, :n_singvals_to_keep]  # Keep only the first columns

        return U, V_t, S


class WorstClassAccuracy(torchmetrics.Accuracy):
    """Accuracy metric reporting the mean of the ``k`` lowest per-class accuracies."""

    def __init__(self, k: int, **kwargs):
        # average=None makes the parent metric return one value per class
        super().__init__(average=None, **kwargs)
        self.k = k
        assert 1 <= k <= self.num_classes

    def compute(self):
        """Return the minimum per-class accuracy for ``k == 1``, else the mean of the ``k`` smallest."""
        per_class = super().compute()
        if self.k != 1:
            worst_k = torch.topk(input=per_class, k=self.k, largest=False).values
            return worst_k.mean()
        return per_class.min()

class ClassAccuracyRange(torchmetrics.Accuracy):
    """Accuracy metric reporting the spread between the best and the worst per-class accuracy."""

    def __init__(self, **kwargs):
        # average=None makes the parent metric return one value per class
        super().__init__(average=None, **kwargs)

    def compute(self):
        """Return ``max - min`` of the per-class accuracies."""
        per_class = super().compute()
        highest, lowest = per_class.max(), per_class.min()
        return highest - lowest


class LAMPUnstructured(prune.BasePruningMethod):
    r"""Prune (currently unpruned) units in a tensor by zeroing out the ones
    with the lowest LAMP score.

    The score of a weight is its square divided by the sum of the squares of all
    weights of the same layer that are at least as large, so layers are ranked on a
    scale that is independent of their overall magnitude.

    Args:
        parameters_to_prune: iterable of ``(module, param_type)`` pairs describing, in order,
            the layers whose (still unpruned) entries make up the vector given to ``compute_mask``.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, parameters_to_prune, amount):
        # The (non-sliced) layer list is needed to split the flat vector back into layers
        self.parameters_to_prune = parameters_to_prune
        # Check range of validity of pruning amount
        _validate_pruning_amount_init(amount)
        self.amount = amount

    def compute_mask(self, t, default_mask):
        """Return a copy of ``default_mask`` with the ``amount`` lowest-scoring entries of ``t`` zeroed.

        ``t`` is treated as the concatenation of the prunable entries of all layers in
        ``self.parameters_to_prune``: for a module that is already pruned only the entries whose
        mask is 1 are counted, otherwise all entries of the parameter.
        """
        n_total = t.nelement()
        # amount if int, else amount * n_total
        n_prune = _compute_nparams_toprune(self.amount, n_total)
        # Raises if more units should be pruned than are available
        _validate_pruning_amount(n_prune, n_total)

        per_layer_scores = []
        offset = 0
        # Modified from https://github.com/jaeho-lee/layer-adaptive-sparsity
        for module, param_type in self.parameters_to_prune:
            if prune.is_pruned(module):
                # Only entries that are still prunable are part of t
                n_layer = int((getattr(module, param_type + '_mask') == 1).sum())
            else:
                n_layer = int(getattr(module, param_type).numel())
            layer_slice = t[offset:offset + n_layer]
            assert layer_slice.numel() == n_layer
            offset += n_layer

            # Squared magnitudes in ascending order together with their original positions
            ascending, order = torch.sort(torch.pow(layer_slice.flatten(), 2), descending=False)
            inclusive = ascending.cumsum(dim=0)
            # Exclusive prefix sum: total of all strictly preceding (smaller) squares
            preceding = torch.zeros(inclusive.shape, device=layer_slice.device)
            preceding[1:] = inclusive[:-1]

            # Normalise each square by the sum of itself and all larger squares
            ascending /= (ascending.sum() - preceding)

            # Scatter the scores back to the original ordering of the layer
            layer_scores = torch.zeros(preceding.shape, device=layer_slice.device)
            layer_scores[order] = ascending
            per_layer_scores.append(layer_scores)

        scores = torch.cat(per_layer_scores)
        assert scores.numel() == t.numel()
        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if n_prune == 0:  # k=0 is not supported, nothing to prune
            return mask

        # largest=False --> bottom k, i.e. the entries with the smallest score get pruned
        lowest = torch.topk(scores.view(-1), k=n_prune, largest=False)
        mask.view(-1)[lowest.indices] = 0
        return mask

    @classmethod
    def apply(cls, module, name, amount):
        r"""Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        NOTE(review): only ``amount`` is forwarded here although ``__init__`` also requires
        ``parameters_to_prune`` -- verify whether this entry point is usable as is.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
        """
        return super(LAMPUnstructured, cls).apply(module, name, amount=amount)

class GradientUnstructured(prune.BasePruningMethod):
    r"""Prune (currently unpruned) units with the smallest ``|weight * gradient|`` score.

    Args:
        parameters_to_prune: iterable of ``(module, param_type)`` pairs describing, in order,
            the layers whose (still unpruned) entries make up the vector given to ``compute_mask``.
        amount (int or float): quantity of parameters to prune (fraction if ``float``,
            absolute number if ``int``).
        gradients: mapping from ``(module, param_type)`` to the gradient of that parameter.
        uniform: if True, ``amount`` is applied to every layer separately, otherwise globally.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, parameters_to_prune, amount, gradients, uniform):
        # Check range of validity of pruning amount
        self.parameters_to_prune = parameters_to_prune  # This is a (non-sliced) vector that is passed implicitly
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.gradients = gradients
        self.uniform = uniform  # If True, apply the pruning uniformly

    def compute_mask(self, t, default_mask):
        """Return a copy of ``default_mask`` with the lowest-scoring entries of ``t`` zeroed.

        ``t`` is treated as the concatenation of the prunable entries of all layers in
        ``self.parameters_to_prune``: for a module that is already pruned only the entries whose
        mask is 1 are counted, otherwise all entries of the parameter.
        """
        tensor_size = t.nelement()
        # Compute number of units to prune: amount if int, else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This raises an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        tensor_list = []
        length_done = 0

        for module, param_type in self.parameters_to_prune:
            # Flatten so that gradients stored in parameter shape line up with the flat slice of t
            d_p = self.gradients[(module, param_type)].flatten()
            if prune.is_pruned(module):
                p_mask = getattr(module, param_type + '_mask')
                mask_length = int((p_mask == 1).sum())  # Number of entries that are still prunable
                d_p = d_p[p_mask.flatten() == 1]  # Keep only the gradients of those entries
            else:
                mask_length = int(getattr(module, param_type).numel())
            p = t[length_done:length_done + mask_length]
            assert p.numel() == mask_length
            assert d_p.numel() == p.numel()
            length_done += mask_length

            tensor_list.append(torch.abs(-p * d_p))
        score_tensor = torch.cat(tensor_list)
        assert score_tensor.numel() == t.numel()
        mask = default_mask.clone(memory_format=torch.contiguous_format)
        if not self.uniform:
            # Select globally
            if nparams_toprune != 0:  # k=0 is not supported, nothing to prune
                # largest=False --> bottom k, i.e. prune the smallest scores
                topk = torch.topk(score_tensor.view(-1), k=nparams_toprune, largest=False)
                mask.view(-1)[topk.indices] = 0
        else:
            # Select per layer
            prune_flags = []
            for local_score_tensor in tensor_list:
                k = _compute_nparams_toprune(self.amount, local_score_tensor.numel())
                # Always contribute a full-length block, even if nothing is pruned in this layer;
                # otherwise the concatenated flags would no longer line up with the mask
                local_flags = torch.zeros_like(local_score_tensor, dtype=torch.bool)
                if k != 0:
                    topk = torch.topk(local_score_tensor.view(-1), k=k, largest=False)
                    local_flags.view(-1)[topk.indices] = True  # Each True becomes a zero in the final mask
                prune_flags.append(local_flags)
            mask.view(-1)[torch.cat(prune_flags)] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount, parameters_to_prune, gradients, uniform=False):
        r"""Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            parameters_to_prune: forwarded to ``__init__``.
            gradients: forwarded to ``__init__``.
            uniform: forwarded to ``__init__`` (which requires it); defaults to global selection.
        """
        return super(GradientUnstructured, cls).apply(module, name, amount=amount,
                                                      parameters_to_prune=parameters_to_prune,
                                                      gradients=gradients, uniform=uniform)

class UndecayedUnstructured(prune.BasePruningMethod):
    r"""Prune (currently unpruned) units with the smallest ``|-weight * gradient + wd * weight^2|`` score.

    Args:
        parameters_to_prune: iterable of ``(module, param_type)`` pairs describing, in order,
            the layers whose (still unpruned) entries make up the vector given to ``compute_mask``.
        amount (int or float): quantity of parameters to prune (fraction if ``float``,
            absolute number if ``int``).
        gradients: mapping from ``(module, param_type)`` to the gradient of that parameter.
        uniform: if True, ``amount`` is applied to every layer separately, otherwise globally.
        wd: weight-decay coefficient used in the score; a falsy value is treated as 0.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, parameters_to_prune, amount, gradients, uniform, wd):
        # Check range of validity of pruning amount
        self.parameters_to_prune = parameters_to_prune  # This is a (non-sliced) vector that is passed implicitly
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.gradients = gradients
        self.uniform = uniform  # If True, apply the pruning uniformly
        self.wd = wd or 0.

    def compute_mask(self, t, default_mask):
        """Return a copy of ``default_mask`` with the lowest-scoring entries of ``t`` zeroed.

        ``t`` is treated as the concatenation of the prunable entries of all layers in
        ``self.parameters_to_prune``: for a module that is already pruned only the entries whose
        mask is 1 are counted, otherwise all entries of the parameter.
        """
        tensor_size = t.nelement()
        # Compute number of units to prune: amount if int, else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This raises an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        tensor_list = []
        length_done = 0

        for module, param_type in self.parameters_to_prune:
            # Flatten so that gradients stored in parameter shape line up with the flat slice of t
            d_p = self.gradients[(module, param_type)].flatten()
            if prune.is_pruned(module):
                p_mask = getattr(module, param_type + '_mask')
                mask_length = int((p_mask == 1).sum())  # Number of entries that are still prunable
                d_p = d_p[p_mask.flatten() == 1]  # Keep only the gradients of those entries
            else:
                mask_length = int(getattr(module, param_type).numel())
            p = t[length_done:length_done + mask_length]
            assert p.numel() == mask_length
            assert d_p.numel() == p.numel()
            length_done += mask_length

            tensor_list.append(torch.abs(-p * d_p + self.wd * (p ** 2)))
        score_tensor = torch.cat(tensor_list)
        assert score_tensor.numel() == t.numel()
        mask = default_mask.clone(memory_format=torch.contiguous_format)
        if not self.uniform:
            # Select globally
            if nparams_toprune != 0:  # k=0 is not supported, nothing to prune
                # largest=False --> bottom k, i.e. prune the smallest scores
                topk = torch.topk(score_tensor.view(-1), k=nparams_toprune, largest=False)
                mask.view(-1)[topk.indices] = 0
        else:
            # Select per layer
            prune_flags = []
            for local_score_tensor in tensor_list:
                k = _compute_nparams_toprune(self.amount, local_score_tensor.numel())
                # Always contribute a full-length block, even if nothing is pruned in this layer;
                # otherwise the concatenated flags would no longer line up with the mask
                local_flags = torch.zeros_like(local_score_tensor, dtype=torch.bool)
                if k != 0:
                    topk = torch.topk(local_score_tensor.view(-1), k=k, largest=False)
                    local_flags.view(-1)[topk.indices] = True  # Each True becomes a zero in the final mask
                prune_flags.append(local_flags)
            mask.view(-1)[torch.cat(prune_flags)] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount, parameters_to_prune, gradients, wd, uniform=False):
        r"""Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            parameters_to_prune: forwarded to ``__init__``.
            gradients: forwarded to ``__init__``.
            wd: forwarded to ``__init__``.
            uniform: forwarded to ``__init__`` (which requires it); defaults to global selection.
        """
        return super(UndecayedUnstructured, cls).apply(module, name, amount=amount,
                                                       parameters_to_prune=parameters_to_prune,
                                                       gradients=gradients, uniform=uniform, wd=wd)


class FairnessStatistics:
    """Keeps track of all fairness related statistics.

    General statistics used:
    - CIE: Compression impacted examples
    - CIE-P: CIE that were correctly classified by the dense model
    - CIE-N: CIE that were incorrectly classified by the dense model, but are correctly classified by the compressed model (Note: CIEP + CIEN != CIE)

    Per-class statistics used:
    - CIE()_rel: fraction of CIE/CIEP/CIEN per class relative to the total amount of CIE/CIEP/CIEN
    - CIEP_relpos: fraction of CIEP per class relative to total correctly classified in the dense model (also per class)

    """

    def __init__(self, n_classes, device):
        self.n_classes = n_classes
        self.device = device

        self.loss_fn = torch.nn.CrossEntropyLoss(reduction='none').to(device=self.device)

        self.classes = torch.arange(0, self.n_classes, 1).to(device=self.device)

        # General class count
        self.class_occurence = torch.zeros_like(self.classes)
        self.dense_correct_occurence = torch.zeros_like(self.classes)
        self.correct_occurence = torch.zeros_like(self.classes)
        self.confusion = {state: {} for state in ['dense', 'sparse']}
        for state in self.confusion.keys():
            for statType in ['TP', 'FN', 'FP', 'TN']:
                self.confusion[state][statType] = torch.zeros_like(self.classes)
        # Compression Impacted Examples
        self.CIE = torch.zeros_like(self.classes)
        self.CIEP = torch.zeros_like(self.classes)
        self.CIEN = torch.zeros_like(self.classes)

        # Dense statistics
        self.dense_loss_per_class = torch.zeros(self.n_classes, device=self.device)
        self.dense_confidence_per_class = torch.zeros(self.n_classes, device=self.device)

    @torch.no_grad()
    def __call__(self, output, output_dense, y_true):
        """
        Accumulate the statistics of one batch; called for each iteration of a single evaluation pass.

        output: Model output of the compressed model
        output_dense: Model output of the dense model
        y_true: true labels
        """

        # General statistics
        present_classes, counts = torch.unique(y_true, return_counts=True)
        self.class_occurence[present_classes] += counts
        self.dense_correct_occurence += self.get_class_correct(output=output_dense, y_true=y_true)
        self.correct_occurence += self.get_class_correct(output=output, y_true=y_true)

        # Confusion counters, separately for the dense and the compressed ('sparse') model
        # NOTE(review): the order TP, FN, FP, TN is what this method unpacks -- confirm that
        # get_confusion_matrix returns its values in the same order.
        for state, counters in self.confusion.items():
            logits = output if state == 'sparse' else output_dense
            TP, FN, FP, TN = self.get_confusion_matrix(output=logits, y_true=y_true)
            for statType, batch_value in zip(('TP', 'FN', 'FP', 'TN'), (TP, FN, FP, TN)):
                counters[statType] += batch_value

        # Compression Impacted Examples
        batch_CIE, batch_CIEP, batch_CIEN = self.get_CIE(output=output, output_dense=output_dense, y_true=y_true)
        self.CIE += batch_CIE
        self.CIEP += batch_CIEP
        self.CIEN += batch_CIEN

        # Dense statistics
        self.add_dense_class_loss(output_dense=output_dense, y_true=y_true)
        self.add_dense_class_confidence(output_dense=output_dense, y_true=y_true)

    @torch.no_grad()
    def get_results(self):
        logDict = {'n_CIE': self.CIE.sum(),
                   'n_CIEP': self.CIEP.sum(),
                   'n_CIEN': self.CIEN.sum(),
                   'class': {cls: dict() for cls in range(self.n_classes)}
                   }

        # Compression Impact Metrics
        for CIE_type in ['CIE', 'CIEP', 'CIEN']:
            CIE_absolute = getattr(self, CIE_type)
            # Get correct denominator for relative values, e.g. for CIEP we want to know the fraction of CIEP(y) versus total dense correct ones in class
            if CIE_type == 'CIE':
                denom = self.class_occurence
            elif CIE_type == 'CIEP':
                denom = self.dense_correct_occurence
            elif CIE_type == 'CIEN':
                denom = self.class_occurence - self.dense_correct_occurence

            CIE_relative = getattr(self, CIE_type) / denom
            for cls in range(self.n_classes):
                logDict['class'][cls][CIE_type + "_abs"] = CIE_absolute[cls].item()
                logDict['class'][cls][CIE_type + "_rel"] = CIE_relative[cls].item()

        # General metrics
        distribution = self.class_occurence / self.class_occurence.sum()

        # Recall
        dense_class_recall = self.dense_correct_occurence / self.class_occurence
        class_recall = self.correct_occurence / self.class_occurence

        # Accuracy
        dense_class_accuracy = dense_class_recall * distribution
        class_accuracy = class_recall * distribution

        # Differences: Recall and Accuracy
        diffDict = dict()
        for diffType in ['recall', 'accuracy']:
            dense_values = dense_class_recall if diffType == 'recall' else dense_class_accuracy
            values = class_recall if diffType == 'recall' else class_accuracy

            diffDict[f'dense_{diffType}'] = dense_values
            diffDict[diffType] = values

            diffDict[f'abs_dense_class_{diffType}_deviation_from_mean'] = dense_values - dense_values.mean()
            diffDict[f'rel_dense_class_{diffType}_deviation_from_mean'] = diffDict[
                                                                              f'abs_dense_class_{diffType}_deviation_from_mean'] / dense_values.mean()

            diffDict[f'abs_{diffType}_diff'] = values - dense_values
            diffDict[f'rel_{diffType}_diff'] = (values - dense_values) / dense_values

            diffDict[f'negative_{diffType}_change'] = torch.clip(dense_values - values, min=0)


        cls_cond_risk = self.dense_loss_per_class * distribution
        cls_cond_risk_rel = cls_cond_risk / cls_cond_risk.sum()
        cls_pred_confidence = self.dense_confidence_per_class / self.class_occurence

        for cls in range(self.n_classes):
            # Class conditioned risk
            logDict['class'][cls]["dense_cond_risk"] = cls_cond_risk[cls].item()
            logDict['class'][cls]["dense_cond_risk_rel"] = cls_cond_risk_rel[cls].item()

            # Distribution
            logDict['class'][cls]["distribution"] = distribution[cls].item()

            # Class prediction confidence
            logDict['class'][cls]["pred_conf"] = cls_pred_confidence[cls].item()

            # Class recall and accuracy
            for name, tensor in diffDict.items():
                logDict['class'][cls][name] = tensor[cls].item()

        # Balanced/Overall accuracy
        logDict['dense_balanced_accuracy'] = dense_class_recall.mean().item()
        logDict['balanced_accuracy'] = class_recall.mean().item()
        logDict['dense_overall_accuracy'] = dense_class_accuracy.sum().item()
        logDict['overall_accuracy'] = class_accuracy.sum().item()

        ### Fairness metrics
        # Unfairness due to Joseph et al.
        dense_unfairness = dense_class_recall.max() - dense_class_recall.min()
        unfairness = class_recall.max() - class_recall.min()
        logDict['dense_unfairness'] = dense_unfairness
        logDict['unfairness'] = unfairness
        logDict['unfairness_change_ratio'] = unfairness / dense_unfairness

        # CVE
        FPR = self.confusion['sparse']['FP'] / (self.confusion['sparse']['FP'] + self.confusion['sparse']['TN'])
        FNR = self.confusion['sparse']['FN'] / (self.confusion['sparse']['FN'] + self.confusion['sparse']['TP'])
        FPR_dense = self.confusion['dense']['FP'] / (self.confusion['dense']['FP'] + self.confusion['dense']['TN'])
        FNR_dense = self.confusion['dense']['FN'] / (self.confusion['dense']['FN'] + self.confusion['dense']['TP'])
        delta_FPR, delta_FNR = (FPR - FPR_dense) / FPR_dense, (FNR - FNR_dense) / FNR_dense
        cve = torch.var(torch.cat((delta_FPR, delta_FNR)))
        logDict['CVE'] = cve

        # New metrics
        for metricType in ['abs', 'rel']:
            neg_recall_change = diffDict[f'negative_recall_change'] / dense_class_recall if metricType == 'rel' else \
            diffDict[f'negative_recall_change']
            neg_recall_change_mean, neg_recall_change_std = neg_recall_change.mean(), neg_recall_change.std()
            logDict[f'{metricType}_neg_recall_change_mean'] = neg_recall_change_mean
            logDict[f'{metricType}_neg_recall_change_std'] = neg_recall_change_std
            logDict[f'{metricType}_neg_recall_change_range'] = neg_recall_change.max() - neg_recall_change_std.min()
            logDict[f'{metricType}_neg_recall_change_max'] = neg_recall_change.max()
            logDict[f'{metricType}_neg_recall_change_min'] = neg_recall_change.min()
            logDict[f'{metricType}_neg_recall_change_quartilerange'] = torch.quantile(neg_recall_change,
                                                                                      q=0.75) - torch.quantile(
                neg_recall_change, q=0.25)

        """
        for diffType in ['recall', 'accuracy']:

            # Neg/Pos contribution
            logDict[f'{diffType}_diff_pos_contribution'] = torch.sum(diffDict[f'abs_{diffType}_diff'][diffDict[f'abs_{diffType}_diff'] >= 0]).item()
            logDict[f'{diffType}_diff_neg_contribution'] = torch.sum(
                diffDict[f'abs_{diffType}_diff'][diffDict[f'abs_{diffType}_diff'] < 0]).item()

            # Further Relative-Class-Accuracy/Recall metrics
            logDict[f'abs_{diffType}_diff_neg_range'] = diffDict[f'abs_{diffType}_diff'][diffDict[f'abs_{diffType}_diff'] < 0].max() - diffDict[f'abs_{diffType}_diff'][diffDict[f'abs_{diffType}_diff'] < 0].min() if diffDict[f'abs_{diffType}_diff'][diffDict[f'abs_{diffType}_diff'] < 0].numel() > 0 else 0
            logDict[f'abs_{diffType}_diff_pos_range'] = diffDict[f'abs_{diffType}_diff'][diffDict[f'abs_{diffType}_diff'] >= 0].max() - diffDict[f'abs_{diffType}_diff'][
                diffDict[f'abs_{diffType}_diff'] >= 0].min() if diffDict[f'abs_{diffType}_diff'][diffDict[f'abs_{diffType}_diff'] >= 0].numel() > 0 else 0
        """
        return {k: v for k, v in logDict.items() if k != 'class'}, {'class': logDict['class']}

    @torch.no_grad()
    def add_dense_class_confidence(self, output_dense, y_true):
        """Accumulates the prediction confidence of the dense output, grouped by the label in y_true.

        The confidence of a sample is the largest entry of the softmax (taken over dim 1) of output_dense. Note
        that this is the confidence in the predicted class, which is not necessarily the correct class. The
        values are summed in-place into self.dense_confidence_per_class at the positions given by y_true.
        """
        class_probabilities = torch.softmax(output_dense, dim=1)  # presumably bs x n_classes
        top_confidence, _ = class_probabilities.max(dim=1)
        self.dense_confidence_per_class.scatter_add_(dim=0,
                                                     index=y_true.reshape(-1),
                                                     src=top_confidence.reshape(-1))

    @torch.no_grad()
    def add_dense_class_loss(self, output_dense, y_true):
        """Accumulates the loss of the dense output, grouped by the label in y_true.

        self.loss_fn is evaluated on (output_dense, y_true) and its flattened result is summed in-place into
        self.dense_loss_per_class at the positions given by y_true.
        """
        # NOTE(review): this only lines up if self.loss_fn returns one value per sample (no reduction) - confirm
        sample_losses = self.loss_fn(output_dense, y_true)
        target_positions = y_true.reshape(-1)
        self.dense_loss_per_class.scatter_add_(dim=0, index=target_positions, src=sample_losses.reshape(-1))

    @torch.no_grad()
    def get_class_correct(self, output, y_true):
        """Counts, for every entry of self.classes, the correctly predicted samples of that class in one batch.

        The prediction of a sample is the index of the largest entry of output along dim 1. The result has one
        count per class (self.n_classes entries).
        """
        n_samples = len(y_true)
        _, predicted = output.max(dim=1)

        # Row c marks the samples whose label equals self.classes[c]
        belongs_to_class = self.classes.unsqueeze(1) == y_true.unsqueeze(0).expand(self.n_classes, n_samples)
        # Per-sample indicator of a correct prediction, repeated for every class row
        is_hit = (predicted == y_true).unsqueeze(0).expand(self.n_classes, n_samples)

        return (belongs_to_class & is_hit).sum(dim=1)

    @torch.no_grad()
    def get_confusion_matrix(self, output, y_true):
        """Computes the per-class confusion matrix entries of one batch.

        The prediction of a sample is the index of the largest entry of output along dim 1.

        Returns the tuple (TP, FN, FP, TN) - in exactly this order - where each element holds one count per
        entry of self.classes.
        """
        batch_size = len(y_true)
        mask_shape = (self.n_classes, batch_size)
        class_column = self.classes.unsqueeze(1)

        # Prediction of the network
        _, predicted = output.max(dim=1)

        # Number of occurrences of every class among the labels of this batch.
        # NOTE(review): the label values are used as indices here, i.e. labels are assumed to be valid positions
        # in self.classes - confirm
        present_classes, present_counts = torch.unique(y_true, return_counts=True)
        class_counts = torch.zeros_like(self.classes)
        class_counts[present_classes] += present_counts

        # Row c marks the samples whose label equals self.classes[c]
        is_label = y_true.unsqueeze(0).expand(*mask_shape) == class_column
        # Row c marks the samples whose prediction equals self.classes[c]
        is_predicted = predicted.unsqueeze(0).expand(*mask_shape) == class_column
        # Per-sample indicator of a correct prediction, repeated for every class row
        is_hit = predicted.eq(y_true).unsqueeze(0).expand(*mask_shape)

        TP = (is_label & is_hit).sum(dim=1)  # Label and prediction coincide
        FP = (is_predicted & ~is_hit).sum(dim=1)  # Class is predicted although it is not the label
        FN = class_counts - TP  # Class is the label but something else is predicted
        # Everything that is neither labelled as the class (P) nor falsely predicted as it (FP)
        TN = torch.full_like(self.classes, batch_size) - class_counts - FP
        return TP, FN, FP, TN

    @torch.no_grad()
    def get_CIE(self, output, output_dense, y_true):
        """Counts per class the samples on which the predictions of output and output_dense disagree.

        Predictions are the indices of the largest entries along dim 1. Returns the tuple (CIE, CIEP, CIEN), each
        holding one count per entry of self.classes (grouped by the label in y_true):
            CIE:  samples where the two predictions differ
            CIEP: those of the above where the prediction of output_dense equals the label
            CIEN: those of the above where the prediction of output equals the label
        """
        mask_shape = (self.n_classes, len(y_true))
        _, pred_sparse = output.max(dim=1)
        _, pred_dense = output_dense.max(dim=1)

        def per_class(sample_values):
            # Repeat a per-sample vector for every class row
            return sample_values.unsqueeze(0).expand(*mask_shape)

        # Row c marks the samples whose label equals self.classes[c]
        is_label = per_class(y_true) == self.classes.unsqueeze(1)
        # Row c marks the samples of class c on which both predictions differ
        disagreement = per_class(pred_sparse.ne(pred_dense)) & is_label

        dense_hit = per_class(pred_dense.eq(y_true))
        sparse_hit = per_class(pred_sparse.eq(y_true))

        CIE = disagreement.sum(dim=1)
        CIEP = (disagreement & dense_hit).sum(dim=1)
        CIEN = (disagreement & sparse_hit).sum(dim=1)
        return CIE, CIEP, CIEN

class DecomposedConv2d(torch.nn.Module):
    """Converts a single Conv2d layer into two Conv2d layers by exploiting truncated SVD.

    The 4d weight of shape (c_out, c_in, k1, k2) is viewed as a matrix of shape (c_out, c_in*k1*k2) by flattening
    in_channels and (k1, k2). With the truncated SVD W ~ U diag(S) V^T of that matrix:
        - self.first is a (k1, k2)-convolution holding diag(S) V^T. It takes over stride, padding, dilation and
          padding_mode of the original layer and has no bias.
        - self.second is a 1x1-convolution holding U. It takes over the bias of the original layer, if any.

    Args:
        originalLayer: The torch.nn.Conv2d layer to decompose. Only groups == 1 is supported.
        svdSparsity: Fraction in [0, 1] of singular values to drop. At least one singular value is always kept.

    Raises:
        NotImplementedError: If originalLayer is a grouped convolution.
    """

    def __init__(self, originalLayer, svdSparsity):
        super().__init__()
        assert isinstance(originalLayer, torch.nn.Conv2d)
        assert 0 <= svdSparsity <= 1
        self.compute_decomposition(originalLayer=originalLayer, svdSparsity=svdSparsity)

    @torch.no_grad()
    def compute_decomposition(self, originalLayer, svdSparsity):
        """Computes the truncated SVD of the weight of originalLayer and creates self.first and self.second."""
        if originalLayer.groups != 1:
            # The weight of a grouped convolution has shape (c_out, c_in / groups, k1, k2); flattening it does not
            # yield the matrix of the full linear map, hence the factorization below would be wrong.
            raise NotImplementedError("DecomposedConv2d only supports convolutions with groups == 1.")

        # Save information of originalLayer
        W = originalLayer.weight
        b = originalLayer.bias  # None if not existing
        c_out, c_in, k1, k2 = W.shape
        W_mat = W.flatten(start_dim=1)  # Shape (c_out, c_in*k1*k2)

        # Compute SVD
        U, S, V_t = torch.linalg.svd(W_mat, full_matrices=False)
        # Number of singular values to keep (int() truncates towards zero), but never less than one
        n_singvals_to_keep = max(1, int((1 - svdSparsity) * S.shape[0]))

        # Truncate matrices
        V_t = V_t[:n_singvals_to_keep, :]  # Keep only the first rows
        S = S[:n_singvals_to_keep]
        U = U[:, :n_singvals_to_keep]  # Keep only the first columns

        # Initialize the two new layers. All spatial hyperparameters live in the first layer, the second one is a
        # pure channel mixing. Both are placed on the device and in the dtype of the original weight.
        self.first = torch.nn.Conv2d(in_channels=c_in, out_channels=n_singvals_to_keep,
                                     kernel_size=(k1, k2), stride=originalLayer.stride,
                                     padding=originalLayer.padding, dilation=originalLayer.dilation,
                                     padding_mode=originalLayer.padding_mode,
                                     bias=False).to(device=W.device, dtype=W.dtype)
        self.second = torch.nn.Conv2d(in_channels=n_singvals_to_keep, out_channels=c_out,
                                      kernel_size=(1, 1),
                                      bias=(b is not None)).to(device=W.device, dtype=W.dtype)

        # Set the parameters as computed in the SVD. Scaling the rows of V_t by S equals diag(S) @ V_t.
        self.first.weight.copy_((S.unsqueeze(1) * V_t).reshape(self.first.weight.shape))
        # U is a column slice and hence possibly non-contiguous, reshape handles that
        self.second.weight.copy_(U.reshape(self.second.weight.shape))
        if b is not None:
            self.second.bias.copy_(b)

    def forward(self, x):
        """Applies the two convolutions one after another."""
        out = self.first(x)
        out = self.second(out)
        return out